from collections import defaultdict
from functools import lru_cache
from typing import Tuple, Dict, Any, List

import numpy as np

from centralized_verification.shields.shield import Shield, T, ShieldResult, AgentResult, AgentUpdate
from centralized_verification.shields.slugs_shielding.combine_identical_states import DecentralizedShieldSpec
from centralized_verification.shields.slugs_shielding.label_extractor import LabelExtractor


class SlugsDecentralizedShield(Shield[T, int]):
    """
    Shield that looks up the per-agent allowed actions in a ``DecentralizedShieldSpec`` automaton.

    The shield state which we use for the decentralized shield is actually the _previous_ shield state
    (agents can't know the current shield state, because that would mean knowledge of the joint action).

    Concretely, the shield state handled by this class is either:

    * ``-1``: the special initial state (no joint action has been evaluated yet), or
    * a tuple ``(automaton_state_num, agent_permutation_idx)`` describing the automaton state and the
      action permutation that were used on the previous call to ``evaluate_joint_action``.
    """

    # Sentinel shield state returned before the first joint action has been evaluated.
    _INITIAL_SHIELD_STATE = -1

    def __init__(self, env: T, shield_spec: DecentralizedShieldSpec, label_extractor: LabelExtractor,
                 random_agent_order: bool = True, **kwargs):
        """
        :param env: Environment being shielded; forwarded to the base class.
        :param shield_spec: Automaton, indexable by state number; each state exposes ``label``,
            ``initial_state`` and ``action_permutations``.
        :param label_extractor: Callable mapping an environment state to the label used by the automaton.
        :param random_agent_order: If true, pick one of the state's action permutations uniformly at
            random on every step; otherwise always use permutation 0.
        :param kwargs: Forwarded unchanged to the base class constructor.
        """
        super().__init__(env, **kwargs)
        self.shield_spec = shield_spec
        self.label_extractor = label_extractor
        self.random_agent_order = random_agent_order
        # Lazily built by _get_label_to_initial_shield_state_dict. Kept per instance (instead of an
        # lru_cache on the method) so that the cache does not keep shield instances alive.
        self._label_to_initial_shield_states = None

    def get_initial_shield_state(self, state, initial_joint_obs) -> int:
        """Return the sentinel (``-1``) marking that no joint action has been evaluated yet."""
        return self._INITIAL_SHIELD_STATE  # Special initial state

    def _get_label_to_initial_shield_state_dict(self) -> Dict[Any, List[int]]:
        """
        Return a mapping from label to the state numbers of all initial automaton states with that label.

        The mapping is computed on first use and then cached on this instance.
        """
        cached = getattr(self, "_label_to_initial_shield_states", None)
        if cached is None:
            cached = defaultdict(list)
            for state_num, state in self.shield_spec.items():
                if state.initial_state:
                    cached[state.label].append(state_num)
            self._label_to_initial_shield_states = cached

        return cached

    def get_actual_state_from_candidates(self, possible_states, label):
        """
        Return the first state number in ``possible_states`` whose automaton state carries ``label``.

        :raises Exception: if none of the candidates has the given label.
        """
        for state_num in possible_states:
            if self.shield_spec[state_num].label == label:
                return state_num

        raise Exception("The shield did not accurately model the environment")

    def evaluate_joint_action(self, state, joint_obs, proposed_action, shield_state: T) -> Tuple[ShieldResult, T]:
        """
        Check every agent's proposed action against the automaton and replace disallowed ones.

        :param state: Current environment state; only used to extract the label.
        :param joint_obs: Unused by this implementation.
        :param proposed_action: Iterable with one proposed action per agent.
        :param shield_state: ``-1`` on the first step, otherwise the ``(state_num, permutation_idx)`` tuple
            returned by the previous call.
        :return: ``(per_agent_results, (current_state_num, permutation_idx))``.
        :raises Exception: if the initial automaton state is missing or ambiguous for the label, or if no
            successor of the previous automaton state matches the label.
        """
        label = self.label_extractor(state)

        if shield_state == self._INITIAL_SHIELD_STATE:
            # .get() rather than indexing: indexing a defaultdict would insert an empty entry into the
            # cached mapping for every unknown label.
            initial_candidates = self._get_label_to_initial_shield_state_dict().get(label, [])
            if not initial_candidates:
                raise Exception("Shield has no possible start state for the given label: {!r}".format(label))
            if len(initial_candidates) > 1:
                raise Exception("Shield has more than one possible start state for the given label. "
                                "Override get_initial_shield_state to specify the desired behavior")

            current_state_num = initial_candidates[0]
        else:
            # The stored state is the one used on the previous step; advance to the successor whose
            # label matches what the environment actually did.
            prev_state_num, prev_permutation_idx = shield_state
            successor_candidates = self.shield_spec[prev_state_num].action_permutations[
                prev_permutation_idx].next_states
            current_state_num = self.get_actual_state_from_candidates(successor_candidates, label)

        current_state = self.shield_spec[current_state_num]
        if self.random_agent_order:
            permutation_idx = np.random.randint(len(current_state.action_permutations))
        else:
            permutation_idx = 0

        allowed_actions_per_agent = current_state.action_permutations[permutation_idx].actions

        results = []
        for proposed_indiv_action, allowed_indiv_actions in zip(proposed_action, allowed_actions_per_agent):
            if proposed_indiv_action in allowed_indiv_actions:
                results.append(AgentResult(AgentUpdate(action=proposed_indiv_action)))
            elif len(allowed_indiv_actions) > 0:
                # replace_action_agent_result is not defined in this class (presumably inherited from
                # Shield); the first allowed action is used as the replacement.
                results.append(self.replace_action_agent_result(proposed_indiv_action, allowed_indiv_actions[0]))
            # NOTE(review): when an agent has no allowed action at all, no entry is appended, so the
            # result list ends up shorter than the number of agents. Behavior kept as-is; confirm
            # whether callers rely on this or whether it should be an error.

        return results, (current_state_num, permutation_idx)
